
#ifndef parameters_sweeps_h
#define parameters_sweeps_h
#include "system_class.h"
#include <cmath>
#include <functional>
#include <iostream>

// Enums for sweep variables and function types.

/// Variables that can be swept. The enumerators are used directly as indices
/// into the four-element min/max/delta arrays of ParameterSweep, so their
/// order and count must stay in sync with those arrays.
enum SweepVariables{InitialParticleNumber, ActivationRate, DimerizationRate, DecayRate};
/// Rate profile selected through ParameterSweep::SetSubcriticalFunction.
/// Constant, Linear and Exponential select a helper function; any other value
/// (e.g. NonFunction) disables the helper.
enum SubcriticalFunctionType{NonFunction,Constant,Linear,Exponential};


/// Linear rate profile: interpolates from `dimerization` at index 0 to 1.0 at
/// index critical_size-1.
///
/// Declared inline because the definition lives in a header; without it every
/// translation unit including this file would emit its own external definition.
/// @param critical_size critical particle size
/// @param index current site
/// @param dimerization dimerization rate (value returned for index 0)
/// @return dimerization+index*(1-dimerization)/(critical_size-1); `dimerization`
///         itself when index is 0 or when critical_size<=1 (no interval to interpolate over)
inline double LinearFunction(int critical_size,int index,double dimerization){
    // critical_size<=1 would divide by zero (or by a negative span) below.
    if (index==0||critical_size<=1) {
        return dimerization;
    }
    const double step=(1.0-dimerization)/static_cast<double>(critical_size-1);
    return dimerization+step*static_cast<double>(index);
}

/// Exponential rate profile: geometric interpolation from `dimerization` at
/// index 0 to 1.0 at index critical_size-1, i.e.
/// dimerization^(1-index/(critical_size-1)).
///
/// Declared inline because the definition lives in a header. The `const double`
/// return type is kept on purpose: it is part of the function type that
/// function pointers to this helper are declared with.
/// @param critical_size critical particle size
/// @param index current site
/// @param dimerization dimerization rate; must be positive for index!=0 because its logarithm is taken
/// @return interpolated rate; `dimerization` itself when index is 0 or when
///         critical_size<=1 (no interval to interpolate over)
inline const double ExponentialFunction(const int critical_size,const int index,const double dimerization){
    // critical_size<=1 would divide by zero (or by a negative span) below.
    if (index==0||critical_size<=1) {
        return dimerization;
    }
    return dimerization*std::exp(-std::log(dimerization)*static_cast<double>(index)/static_cast<double>(critical_size-1));
}

/// Constant rate profile: returns the dimerization rate for every site.
///
/// The first two parameters exist only so that the signature matches the other
/// rate-profile helpers. Declared inline because the definition lives in a
/// header; the `const double` return type is kept because it is part of the
/// function type that function pointers to this helper are declared with.
/// @param critical_size critical particle size (unused)
/// @param index current site (unused)
/// @param dimerization dimerization rate
/// @return `dimerization`, unchanged
inline const double Identity(const int critical_size,const int index,const double dimerization){
    (void)critical_size;
    (void)index;
    return dimerization;
}

// Pointers consulted by StopBecauseOfException: when non-null, the vector is
// written to an HDF5 file and the HDF5 file handle is closed before the
// program exits. Both start out as nullptr, i.e. nothing is rescued by default.
// NOTE(review): these are non-inline variable definitions in a header; if this
// file is included from more than one translation unit the linker will see
// duplicate symbols — confirm it is only included once.
std::vector<double>* data_which_can_be_rescued=nullptr;
H5::H5File* time_traces_which_can_be_rescued=nullptr;


/// Class to perform systematic parameter sweeps over RingStructure simulations.
///
/// For each sweep variable a range [min, max] and a multiplicative step delta
/// are stored; Sweep() visits every combination of the four ranges.
/// @tparam ring_size forwarded to RingStructure; the reaction rate array holds ring_size-1 entries
template <int ring_size>
class ParameterSweep{
    
public:
    
    /// Constructor. Every range initially consists of the single value given
    /// here (min == max) with a multiplicative step of 2.
    /// @param initial_particle_number initial number of inactive particles
    /// @param activation_rate activation rate
    /// @param dimerization_rate dimerization rate
    /// @param decay_rate decay rate
    /// @param reaction_rates array with all reaction rates
    /// @param record_time_trajectories record intermediate states?
    /// @param time_steps steps for recording
    /// @param seed random number seed
    ParameterSweep(const long initial_particle_number,const double activation_rate,const double dimerization_rate,const double decay_rate,std::array<double, ring_size-1> reaction_rates,const bool record_time_trajectories,const double time_steps,const long seed):
    // Initializers are listed in declaration order, which is the order in which they run.
    non_trivial_activation_{false},
    record_time_trajectories_{record_time_trajectories},
    time_steps_{time_steps},
    group_index_{0},
    reaction_rates_(reaction_rates),
    rng_engine_(seed),
    activation_function_{}
    {
        min_values_[InitialParticleNumber]=static_cast<double>(initial_particle_number);
        min_values_[ActivationRate]=activation_rate;
        min_values_[DimerizationRate]=dimerization_rate;
        min_values_[DecayRate]=decay_rate;
        
        max_values_=min_values_;
        delta_values_.fill(2.0);
    }
    
    /// Set the range for a parameter sweep. Invalid ranges are reported on
    /// standard output and leave the stored range untouched.
    /// @param sweep_variable enum of the variable
    /// @param min minimal value (must be positive)
    /// @param max maximal value (must not be below min)
    /// @param delta multiplicative step (must be greater than 1)
    void SetSweepRange(SweepVariables sweep_variable,double min,double max, double delta){
        // Written as positive conditions so that NaN arguments are rejected too.
        bool valid=(min>0)&&(max>=min)&&(delta>1);
        if (valid&&sweep_variable==InitialParticleNumber) {
            // Sweep() truncates the start value and the step of the particle
            // number to long; a start of 0 or a step of 1 would never advance.
            valid=(min>=1)&&(delta>=2);
        }
        if (!valid) {
            std::cout<<"Not a valid parameter sweep\n";
            return;
        }
        min_values_[sweep_variable]=min;
        max_values_[sweep_variable]=max;
        delta_values_[sweep_variable]=delta;
    }
    
    /// Choose the function that derives the first critical_size-1 reaction
    /// rates from the dimerization rate during Sweep().
    /// @param type enum for function type; any value other than Constant,
    ///             Linear or Exponential disables the function
    void SetSubcriticalFunction(SubcriticalFunctionType type){
        switch (type) {
            case Constant:
                activation_function_=Identity;
                break;
            case Linear:
                activation_function_=LinearFunction;
                break;
            case Exponential:
                activation_function_=ExponentialFunction;
                break;
            default:
                activation_function_=nullptr;
                break;
        }
        non_trivial_activation_=static_cast<bool>(activation_function_);
    }
    
    /// Perform the parameter sweep. For every parameter combination eight
    /// values are appended to sweep_data: ring_size, number_of_species,
    /// critical_size, particle number, decay rate, dimerization rate,
    /// activation rate and the value returned by RingStructure::GetEnsAv.
    /// @tparam number_of_species forwarded to RingStructure
    /// @tparam critical_size forwarded to RingStructure; critical_size-1 reaction
    ///         rates are overwritten when a subcritical function is selected
    /// @param number_of_ens number of ensembles
    /// @param sweep_data vector for results return
    /// @param file_time_trace file for additional data
    template<int number_of_species,int critical_size>
    void Sweep(int number_of_ens,std::vector<double>& sweep_data,H5::H5File& file_time_trace){
        // The particle number is swept in integer arithmetic.
        const long first_particle_number=static_cast<long>(min_values_[InitialParticleNumber]);
        const long last_particle_number=static_cast<long>(max_values_[InitialParticleNumber]);
        const long particle_factor=static_cast<long>(delta_values_[InitialParticleNumber]);
        if (first_particle_number<1||particle_factor<2) {
            // A non-positive start or a factor of 1 would never reach the upper bound.
            std::cout<<"Not a valid parameter sweep\n";
            return;
        }
        
        // A decay rate of exactly zero cannot be stepped multiplicatively. In that
        // case the decay loop runs once over a placeholder value and zero is used
        // for the simulation and the output instead.
        const bool has_decay=(min_values_[DecayRate]!=0);
        const double first_decay_rate=has_decay?min_values_[DecayRate]:0.001;
        const double last_decay_rate=has_decay?max_values_[DecayRate]:0.001;
        
        // Draws a fresh seed for every simulated parameter combination.
        std::uniform_int_distribution<long> seed_distribution(0,~0u);
        
        for (long tmp_particle_number=first_particle_number; tmp_particle_number<=last_particle_number; tmp_particle_number*=particle_factor) {
            // The factor 1.00001 keeps the upper bound inclusive despite rounding
            // in the repeated multiplications.
            for (double tmp_decay_rate=first_decay_rate; tmp_decay_rate<1.00001*last_decay_rate; tmp_decay_rate*=delta_values_[DecayRate]) {
                const double decay_rate=has_decay?tmp_decay_rate:0.0;
                std::array<double, critical_size-2> decay_rates;
                decay_rates.fill(decay_rate);
                
                for (double tmp_dimerization_rate=min_values_[DimerizationRate]; tmp_dimerization_rate<1.00001*max_values_[DimerizationRate]; tmp_dimerization_rate*=delta_values_[DimerizationRate]) {
                    if (non_trivial_activation_) {
                        for (int i=0; i<critical_size-1; ++i) {
                            reaction_rates_[i]=activation_function_(critical_size,i,tmp_dimerization_rate);
                        }
                    }
                    else{
                        reaction_rates_[0]=tmp_dimerization_rate;
                    }
                    
                    for (double tmp_activation_rate=min_values_[ActivationRate]; tmp_activation_rate<1.00001*max_values_[ActivationRate]; tmp_activation_rate*=delta_values_[ActivationRate]) {
                        RingStructure<ring_size, number_of_species, critical_size> sample(tmp_particle_number,tmp_activation_rate,reaction_rates_,decay_rates,record_time_trajectories_,time_steps_,seed_distribution(rng_engine_));
                        
                        sweep_data.push_back(static_cast<double>(ring_size));
                        sweep_data.push_back(static_cast<double>(number_of_species));
                        sweep_data.push_back(static_cast<double>(critical_size));
                        sweep_data.push_back(static_cast<double>(tmp_particle_number));
                        sweep_data.push_back(decay_rate);
                        sweep_data.push_back(tmp_dimerization_rate);
                        sweep_data.push_back(tmp_activation_rate);
                        sweep_data.push_back(sample.GetEnsAv(number_of_ens,"time_trace_"+std::to_string(group_index_),file_time_trace));
                        ++group_index_;
                    }
                }
            }
        }
    }
    
private:
    
    // variables for all sweep features
    bool non_trivial_activation_;        // true while a subcritical function is selected
    bool record_time_trajectories_;
    double time_steps_;
    int group_index_;                    // running index used to name the time-trace groups
    std::array<double, ring_size-1> reaction_rates_;
    std::mt19937 rng_engine_;
    std::array<double, 4> min_values_;   // indexed by SweepVariables
    std::array<double, 4> max_values_;   // indexed by SweepVariables
    std::array<double, 4> delta_values_; // indexed by SweepVariables
    // std::function instead of a raw function pointer: the rate-profile helpers
    // do not all share one function type (`double` vs. `const double` return),
    // so no single pointer type can hold all of them.
    std::function<double(int,int,double)> activation_function_;
};

/// Write 2D data to hdf5
/// @param source data source
/// @param number_of_arguments number of different arguments
/// @param name file name
void Write2DVectorToHDF5(std::vector<double>& source,int number_of_arguments,std::string name){
    hsize_t dims[2];
    dims[0]=source.size()/number_of_arguments;
    dims[1]=number_of_arguments;
    name.append(".hdf5");
    H5::H5File file(name.data(),H5F_ACC_TRUNC);
    H5::DataSpace dataspace(2,dims);
    H5::DataSet dataset=file.createDataSet("parameter_sweep", H5::PredType::NATIVE_DOUBLE, dataspace);
    dataset.write(source.data(), H5::PredType::NATIVE_DOUBLE);
    file.close();
    
}

/// Function to call in case the simulation gets terminated: saves whatever the
/// rescue pointers refer to and then exits with `type` as exit status.
///
/// If data_which_can_be_rescued is set, its contents are written to
/// "rescued_data_<unix time>.hdf5"; if time_traces_which_can_be_rescued is
/// set, that file is closed. Declared inline because the definition lives in a
/// header.
/// NOTE(review): if this is installed as a signal handler, the HDF5 calls and
/// exit() are not async-signal-safe — confirm how it is registered.
/// @param type signal type
inline void StopBecauseOfException(int type){
    // ParameterSweep::Sweep appends eight values per simulated parameter set.
    const int values_per_record=8;
    if (data_which_can_be_rescued) {
        Write2DVectorToHDF5(*data_which_can_be_rescued, values_per_record, std::string("rescued_data_"+std::to_string(time(nullptr))));
    }
    if (time_traces_which_can_be_rescued) {
        time_traces_which_can_be_rescued->close();
    }
    exit(type);
}

#endif /* parameters_sweeps_h */
